Compare commits


6 Commits

Author SHA1 Message Date
Awni Hannun
c763fe1be0 default strict mode for module update and update_modules (#2239) 2025-06-05 15:27:02 -07:00
Cheng
52dc8c8cd5 Add profiler annotations in common primitives for CUDA backend (#2244) 2025-06-04 19:55:12 -07:00
Angelos Katharopoulos
aede70e81d Perf regression fix (#2243) 2025-06-03 17:55:12 -07:00
Cheng
85a8beb5e4 Avoid atomic updates across CPU/GPU in CUDA event (#2231) 2025-06-03 16:49:06 -07:00
Cheng
0bb89e9e5f Share more common code in Compiled (#2240) 2025-06-03 16:48:50 -07:00
* Share more common code in Compiled
* Remove build_lib_name
Cheng
5685ceb3c7 Avoid invoking allocator::malloc when creating CUDA event (#2232) 2025-06-03 16:48:40 -07:00
17 changed files with 352 additions and 347 deletions

View File

@@ -1,8 +1,7 @@
 // Copyright © 2023-2024 Apple Inc.
 #include "mlx/backend/common/compiled.h"
-#include "mlx/graph_utils.h"
-#include "mlx/primitives.h"
+#include "mlx/backend/common/utils.h"
 #include "mlx/utils.h"
 namespace mlx::core {
@@ -79,55 +78,6 @@ std::string get_type_string(Dtype d) {
   }
 }
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids) {
-  NodeNamer namer;
-  std::ostringstream os;
-  std::ostringstream constant_hasher;
-  // Fill the input names. This is not really necessary, I just like having A,
-  // B, C, ... as the inputs.
-  for (auto& x : inputs) {
-    namer.get_name(x);
-  }
-  // The primitives describing the tape. For unary and binary primitives this
-  // must be enough to describe the full computation.
-  for (auto& a : tape) {
-    // name and type of output
-    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
-    // computation performed
-    a.primitive().print(os);
-    // name of inputs to the function
-    for (auto& inp : a.inputs()) {
-      os << namer.get_name(inp);
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      os << "C";
-      print_constant(constant_hasher, x);
-    } else {
-      os << (is_scalar(x) ? "S" : "V");
-    }
-  }
-  os << "_";
-  for (auto& x : inputs) {
-    if (constant_ids.find(x.id()) != constant_ids.end()) {
-      continue;
-    }
-    os << kindof(x.dtype()) << x.itemsize();
-  }
-  os << "_" << std::hash<std::string>{}(constant_hasher.str());
-  return os.str();
-}
 bool compiled_check_contiguity(
     const std::vector<array>& inputs,
     const Shape& shape) {
@@ -159,8 +109,7 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous) {
   if (contiguous) {
     int o = 0;
@@ -175,8 +124,7 @@ void compiled_allocate_outputs(
       // - Donatable
       // - Not a constant
       if (in.itemsize() == outputs[o].itemsize() && !is_scalar(in) &&
-          in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          in.is_donatable() && !is_constant(i)) {
         outputs[o++].copy_shared_buffer(in);
       }
       // Get representative input flags to properly set non-donated outputs
@@ -204,7 +152,7 @@ void compiled_allocate_outputs(
       // - Not a constant
       if (in.flags().row_contiguous && in.size() == outputs[o].size() &&
           in.itemsize() == outputs[o].itemsize() && in.is_donatable() &&
-          constant_ids_.find(inputs_[i].id()) == constant_ids_.end()) {
+          !is_constant(i)) {
         outputs[o].copy_shared_buffer(
             in, outputs[o].strides(), in.flags(), in.data_size());
         o++;
@@ -216,4 +164,74 @@ void compiled_allocate_outputs(
   }
 }
+
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant) {
+  const Shape& shape = out.shape();
+  bool contiguous = compiled_check_contiguity(inputs, shape);
+  if (contiguous) {
+    return {true, shape, {}};
+  }
+
+  std::vector<Strides> strides_vec{out.strides()};
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    // Skip constants.
+    if (is_constant(i)) {
+      continue;
+    }
+
+    // Skip scalar inputs.
+    const auto& x = inputs[i];
+    if (is_scalar(x)) {
+      continue;
+    }
+
+    // Broadcast the inputs to the output shape.
+    Strides xstrides;
+    size_t j = 0;
+    for (; j < shape.size() - x.ndim(); ++j) {
+      if (shape[j] == 1) {
+        xstrides.push_back(out.strides()[j]);
+      } else {
+        xstrides.push_back(0);
+      }
+    }
+    for (size_t i = 0; i < x.ndim(); ++i, ++j) {
+      if (x.shape(i) == 1) {
+        if (shape[j] == 1) {
+          xstrides.push_back(out.strides()[j]);
+        } else {
+          xstrides.push_back(0);
+        }
+      } else {
+        xstrides.push_back(x.strides()[i]);
+      }
+    }
+    strides_vec.push_back(std::move(xstrides));
+  }
+
+  auto tup = collapse_contiguous_dims(shape, strides_vec, INT32_MAX);
+  return {false, std::move(std::get<0>(tup)), std::move(std::get<1>(tup))};
+}
+
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous) {
+  if (contiguous) {
+    size_t max_size = 0;
+    for (const auto& in : inputs) {
+      max_size = std::max(max_size, in.data_size());
+    }
+    return max_size > UINT32_MAX;
+  } else {
+    size_t max_size = 0;
+    for (const auto& o : outputs) {
+      max_size = std::max(max_size, o.size());
+    }
+    return max_size > UINT32_MAX;
+  }
+}
+
 } // namespace mlx::core

View File

@@ -1,9 +1,8 @@
 // Copyright © 2023-2024 Apple Inc.
 #pragma once
+#include <functional>
 #include <iomanip>
-#include <sstream>
-#include <unordered_set>
 #include "mlx/array.h"
 #include "mlx/primitives.h"
@@ -14,12 +13,6 @@ inline bool is_static_cast(const Primitive& p) {
   return (typeid(p) == typeid(Broadcast) || typeid(p) == typeid(AsType));
 }
-std::string build_lib_name(
-    const std::vector<array>& inputs,
-    const std::vector<array>& outputs,
-    const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids);
 std::string get_type_string(Dtype d);
 template <typename T>
@@ -60,8 +53,19 @@ bool compiled_check_contiguity(
 void compiled_allocate_outputs(
     const std::vector<array>& inputs,
     std::vector<array>& outputs,
-    const std::vector<array>& inputs_,
-    const std::unordered_set<uintptr_t>& constant_ids_,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous);
+
+// Collapse contiguous dims ignoring scalars and constants.
+std::tuple<bool, Shape, std::vector<Strides>> compiled_collapse_contiguous_dims(
+    const std::vector<array>& inputs,
+    const array& out,
+    const std::function<bool(size_t)>& is_constant);
+
+// Return whether the kernel should use large index.
+bool compiled_use_large_index(
+    const std::vector<array>& inputs,
+    const std::vector<array>& outputs,
+    bool contiguous);
 } // namespace mlx::core

View File

@@ -146,18 +146,9 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
 #ifdef _MSC_VER
@@ -170,14 +161,15 @@ inline void build_kernel(
   // Add the input arguments
   int cnt = 0;
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
     auto tstr = get_type_string(x.dtype());
     os << " " << tstr << "* " << xname << " = (" << tstr << "*)args[" << cnt++
        << "];" << std::endl;
@@ -211,10 +203,11 @@ inline void build_kernel(
   }
   // Read the inputs in tmps
-  for (auto& x : inputs) {
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    const auto& x = inputs[i];
     auto& xname = namer.get_name(x);
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       os << " " << get_type_string(x.dtype()) << " tmp_" << xname << " = ";
       print_constant(os, x);
       os << ";" << std::endl;
@@ -264,8 +257,9 @@ inline void build_kernel(
   } else {
     for (int d = ndim - 1; d >= 0; --d) {
       // Update pointers
-      for (auto& x : inputs) {
-        if (is_constant(x) || is_scalar(x)) {
+      for (size_t i = 0; i < inputs.size(); ++i) {
+        const auto& x = inputs[i];
+        if (is_constant(i) || is_scalar(x)) {
           continue;
         }
         auto& xname = namer.get_name(x);
@@ -287,65 +281,37 @@ inline void build_kernel(
 void Compiled::eval_cpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
-  // Figure out which kernel we are using
-  auto& shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, shape);
-
   auto& encoder = cpu::get_command_encoder(stream());
 
-  // Handle all broadcasting and collect function input arguments
+  // Collapse contiguous dims to route to a faster kernel if possible. Also
+  // handle all broadcasting.
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
+
+  // Collect function input arguments.
   std::vector<void*> args;
-  std::vector<std::vector<size_t>> strides;
-  for (int i = 0; i < inputs.size(); i++) {
-    // Skip constants.
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+  int strides_index = 1;
+  for (size_t i = 0; i < inputs.size(); ++i) {
+    if (is_constant_(i)) {
       continue;
     }
-    auto& x = inputs[i];
+    const auto& x = inputs[i];
     encoder.set_input_array(x);
     args.push_back((void*)x.data<void>());
-
-    if (contiguous || is_scalar(x)) {
-      continue;
-    }
-
-    // Broadcast the input to the output shape.
-    std::vector<size_t> xstrides;
-    int j = 0;
-    for (; j < shape.size() - x.ndim(); j++) {
-      if (shape[j] == 1) {
-        xstrides.push_back(outputs[0].strides()[j]);
-      } else {
-        xstrides.push_back(0);
-      }
-    }
-    for (int i = 0; i < x.ndim(); i++, j++) {
-      if (x.shape(i) == 1) {
-        if (shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      } else {
-        xstrides.push_back(x.strides()[i]);
-      }
-    }
-    strides.push_back(std::move(xstrides));
-    args.push_back(strides.back().data());
+    if (!contiguous && !is_scalar(x)) {
+      args.push_back(strides[strides_index++].data());
+    }
   }
 
   // Get the kernel name from the lib
   int ndim = shape.size();
   auto kernel_name = kernel_lib_ + (contiguous ? "_contiguous" : "_strided_");
   if (!contiguous) {
-    kernel_name += std::to_string(shape.size());
+    kernel_name += std::to_string(ndim);
   }
 
   // Get the function
-  auto fn_ptr = compile(kernel_name, [&]() {
+  auto fn_ptr = compile(kernel_name, [&, contiguous = contiguous]() {
     std::ostringstream kernel;
     kernel << get_kernel_preamble() << std::endl;
     kernel << "extern \"C\" {" << std::endl;
@@ -355,7 +321,7 @@ void Compiled::eval_cpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         contiguous,
         ndim);
     // Close extern "C"
@@ -363,26 +329,22 @@ void Compiled::eval_cpu(
     return kernel.str();
   });
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   for (auto& x : outputs) {
     args.push_back(x.data<void>());
     encoder.set_output_array(x);
   }
-  Shape out_shape;
   if (!contiguous) {
-    out_shape = outputs[0].shape();
-    args.push_back((void*)out_shape.data());
+    args.push_back((void*)shape.data());
   } else {
     args.push_back((void*)outputs[0].data_size());
   }
   auto fun = (void (*)(void**))fn_ptr;
-  encoder.dispatch(
-      [fun,
-       args = std::move(args),
-       strides = std::move(strides),
-       out_shape = std::move(out_shape)]() mutable { fun(args.data()); });
+  encoder.dispatch([fun,
+                    args = std::move(args),
+                    strides = std::move(strides),
+                    shape = std::move(shape)]() mutable { fun(args.data()); });
 }
 } // namespace mlx::core

View File

@@ -10,13 +10,15 @@ target_sources(
   ${CMAKE_CURRENT_SOURCE_DIR}/device.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/eval.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/event.cu
-  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cu
+  ${CMAKE_CURRENT_SOURCE_DIR}/fence.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/kernel_utils.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/primitives.cu
   ${CMAKE_CURRENT_SOURCE_DIR}/slicing.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/utils.cpp
   ${CMAKE_CURRENT_SOURCE_DIR}/worker.cpp)
+
+target_compile_definitions(mlx PRIVATE MLX_USE_CUDA)
 
 # Enable defining device lambda functions.
 target_compile_options(mlx
                        PRIVATE "$<$<COMPILE_LANGUAGE:CUDA>:--extended-lambda>")

View File

@@ -18,7 +18,10 @@ CudaAllocator::CudaAllocator()
     : buffer_cache_(
           getpagesize(),
           [](CudaBuffer* buf) { return buf->size; },
-          [this](CudaBuffer* buf) { cuda_free(buf); }) {
+          [this](CudaBuffer* buf) {
+            cuda_free(buf->data);
+            delete buf;
+          }) {
   // TODO: Set memory limit for multi-device.
   size_t free, total;
   CHECK_CUDA_ERROR(cudaMemGetInfo(&free, &total));
@@ -70,7 +73,8 @@ void CudaAllocator::free(Buffer buffer) {
     buffer_cache_.recycle_to_cache(buf);
   } else {
     lock.unlock();
-    cuda_free(buf);
+    cuda_free(buf->data);
+    delete buf;
   }
 }
@@ -87,6 +91,25 @@ void CudaAllocator::register_this_thread() {
   allowed_threads_.insert(std::this_thread::get_id());
 }
+
+void CudaAllocator::cuda_free(void* buf) {
+  // If cuda_free() is called from an unregistered thread, reschedule the call
+  // to the worker.
+  {
+    std::lock_guard lock(worker_mutex_);
+    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
+      if (!worker_) {
+        worker_.reset(new Worker);
+      }
+      worker_->add_task([this, buf]() { this->cuda_free(buf); });
+      worker_->end_batch();
+      worker_->commit();
+      return;
+    }
+  }
+  cudaFree(buf);
+}
+
 size_t CudaAllocator::get_active_memory() const {
   return active_memory_;
 }
@@ -125,26 +148,6 @@ void CudaAllocator::clear_cache() {
   buffer_cache_.clear();
 }
-
-void CudaAllocator::cuda_free(CudaBuffer* buf) {
-  // If cuda_free() is called from a unregistered thread, reschedule the call to
-  // worker.
-  {
-    std::lock_guard lock(worker_mutex_);
-    if (allowed_threads_.count(std::this_thread::get_id()) == 0) {
-      if (!worker_) {
-        worker_.reset(new Worker);
-      }
-      worker_->add_task([this, buf]() { this->cuda_free(buf); });
-      worker_->end_batch();
-      worker_->commit();
-      return;
-    }
-  }
-  cudaFree(buf->data);
-  delete buf;
-}
-
 CudaAllocator& allocator() {
   // By creating the |allocator_| on heap, the destructor of CudaAllocator
   // will not be called on exit and buffers in the cache will be leaked. This

View File

@@ -34,6 +34,9 @@ class CudaAllocator : public allocator::Allocator {
   // buffers there would result in dead lock.
   void register_this_thread();
+
+  // Call cudaFree in the safe thread.
+  void cuda_free(void* buf);
 
   size_t get_active_memory() const;
   size_t get_peak_memory() const;
   void reset_peak_memory();
@@ -47,8 +50,6 @@ class CudaAllocator : public allocator::Allocator {
   CudaAllocator();
   friend CudaAllocator& allocator();
 
-  void cuda_free(CudaBuffer* buf);
-
   std::mutex worker_mutex_;
   std::unique_ptr<Worker> worker_;
   std::set<std::thread::id> allowed_threads_;

View File

@@ -1,5 +1,6 @@
 // Copyright © 2024 Apple Inc.
+#include "mlx/backend/cuda/allocator.h"
 #include "mlx/backend/cuda/device.h"
 #include "mlx/backend/cuda/event.h"
 #include "mlx/backend/cuda/utils.h"
@@ -111,12 +112,12 @@ __global__ void event_signal_kernel(SharedEvent::Atomic* ac, uint64_t value) {
 SharedEvent::SharedEvent() {
   // Allocate cuda::atomic on managed memory.
-  allocator::Buffer buffer = allocator::malloc(sizeof(Atomic));
-  Atomic* ac = static_cast<Atomic*>(buffer.raw_ptr());
+  Atomic* ac;
+  CHECK_CUDA_ERROR(cudaMallocManaged(&ac, sizeof(Atomic)));
   new (ac) Atomic(0);
-  ac_ = std::shared_ptr<Atomic>(ac, [buffer](Atomic* ptr) {
+  ac_ = std::shared_ptr<Atomic>(ac, [](Atomic* ptr) {
     ptr->~Atomic();
-    allocator::free(buffer);
+    allocator().cuda_free(ptr);
   });
 }
@@ -155,7 +156,10 @@ void SharedEvent::signal(cudaStream_t stream, uint64_t value) {
 void SharedEvent::signal(Stream s, uint64_t value) {
   nvtx3::scoped_range r("cu::SharedEvent::signal(s)");
   if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [*this, value]() mutable { signal(value); });
+    // Signal through a GPU stream so the atomic is updated on the GPU;
+    // updating the atomic from the CPU sometimes does not notify the GPU.
+    static CudaStream stream(device(mlx::core::Device::gpu));
+    scheduler::enqueue(s, [*this, value]() mutable { signal(stream, value); });
   } else {
     auto& encoder = get_command_encoder(s);
     encoder.launch_kernel(

View File

@@ -0,0 +1,29 @@
+// Copyright © 2025 Apple Inc.
+
+#include "mlx/fence.h"
+#include "mlx/backend/cuda/event.h"
+
+namespace mlx::core {
+
+struct FenceImpl {
+  uint32_t count;
+  cu::SharedEvent event;
+};
+
+Fence::Fence(Stream s) {
+  fence_ = std::shared_ptr<void>(
+      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
+}
+
+void Fence::wait(Stream s, const array&) {
+  auto* fence = static_cast<FenceImpl*>(fence_.get());
+  fence->event.wait(fence->count);
+}
+
+void Fence::update(Stream s, const array&) {
+  auto* fence = static_cast<FenceImpl*>(fence_.get());
+  fence->count++;
+  fence->event.signal(s, fence->count);
+}
+
+} // namespace mlx::core

View File

@@ -1,70 +0,0 @@
-// Copyright © 2025 Apple Inc.
-
-#include "mlx/backend/cuda/device.h"
-#include "mlx/backend/cuda/event.h"
-#include "mlx/fence.h"
-#include "mlx/scheduler.h"
-
-#include <nvtx3/nvtx3.hpp>
-
-namespace mlx::core {
-
-namespace {
-
-__host__ __device__ void busy_wait(cuda::atomic<uint64_t>* ac, uint64_t value) {
-  while (true) {
-    // In theory the atomic_thread_fence is not needed, but for CUDA 11 without
-    // it the load() may never return new value.
-    cuda::atomic_thread_fence(cuda::memory_order_seq_cst);
-    uint64_t current = ac->load();
-    if (current >= value) {
-      break;
-    }
-  }
-}
-
-__global__ void busy_wait_kernel(cuda::atomic<uint64_t>* ac, uint64_t value) {
-  busy_wait(ac, value);
-}
-
-} // namespace
-
-struct FenceImpl {
-  uint32_t count;
-  cu::SharedEvent event;
-};
-
-Fence::Fence(Stream s) {
-  fence_ = std::shared_ptr<void>(
-      new FenceImpl{0}, [](void* ptr) { delete static_cast<FenceImpl*>(ptr); });
-}
-
-void Fence::wait(Stream s, const array&) {
-  auto* fence = static_cast<FenceImpl*>(fence_.get());
-  // We can't use SharedEvent::wait because it could hang in CUDA 11, see also:
-  // https://github.com/ml-explore/mlx/issues/2137
-  const auto& ac = fence->event.atomic();
-  if (s.device == mlx::core::Device::cpu) {
-    scheduler::enqueue(s, [ac, count = fence->count]() {
-      nvtx3::scoped_range r("Fence::wait()");
-      busy_wait(ac.get(), count);
-    });
-  } else {
-    nvtx3::scoped_range r("Fence::wait(s)");
-    auto& encoder = cu::get_command_encoder(s);
-    encoder.launch_kernel(
-        encoder.stream().last_cuda_stream(), [&](cudaStream_t stream) {
-          busy_wait_kernel<<<1, 1, 0>>>(ac.get(), fence->count);
-        });
-    encoder.add_completed_handler([ac]() {});
-    encoder.end_encoding();
-  }
-}
-
-void Fence::update(Stream s, const array&) {
-  auto* fence = static_cast<FenceImpl*>(fence_.get());
-  fence->count++;
-  fence->event.signal(s, fence->count);
-}
-
-} // namespace mlx::core

View File

@@ -5,9 +5,17 @@
 #include "mlx/backend/gpu/copy.h"
 #include "mlx/backend/gpu/slicing.h"
 
+#if defined(MLX_USE_CUDA)
+#include <nvtx3/nvtx3.hpp>
+#endif
+
 #include <cassert>
 
+#if defined(MLX_USE_CUDA)
+#define MLX_PROFILER_RANGE(message) nvtx3::scoped_range r(message)
+#else
 #define MLX_PROFILER_RANGE(message)
+#endif
 
 namespace mlx::core {

View File

@@ -11,8 +11,6 @@
 #include "mlx/primitives.h"
 #include "mlx/utils.h"
 
-using namespace fmt::literals;
-
 namespace mlx::core {
 
 inline void build_kernel(
@@ -21,21 +19,12 @@ inline void build_kernel(
     const std::vector<array>& inputs,
     const std::vector<array>& outputs,
     const std::vector<array>& tape,
-    const std::unordered_set<uintptr_t>& constant_ids,
+    const std::function<bool(size_t)>& is_constant,
     bool contiguous,
     int ndim,
    bool dynamic_dims,
     bool use_big_index = false,
     int work_per_thread = 1) {
-  // All outputs should have the exact same shape and will be row contiguous
-  auto output_shape = outputs[0].shape();
-  auto output_strides = outputs[0].strides();
-
-  // Constants are scalars that are captured by value and cannot change
-  auto is_constant = [&constant_ids](const array& x) {
-    return constant_ids.find(x.id()) != constant_ids.end();
-  };
-
   NodeNamer namer;
   bool add_indices = false;
   int cnt = 0;
@@ -45,14 +34,15 @@ inline void build_kernel(
       "[[host_name(\"{0}\")]]\n[[kernel]] void {0}(\n", kernel_name);
 
   // Add the input arguments
-  for (auto& x : inputs) {
-    auto& xname = namer.get_name(x);
+  for (size_t i = 0; i < inputs.size(); ++i) {
     // Skip constants from the input list
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       continue;
     }
+    const auto& x = inputs[i];
+    auto& xname = namer.get_name(x);
     // Scalars and contiguous need no strides
     if (!is_scalar(x) && !contiguous) {
       add_indices = true;
@@ -80,8 +70,6 @@ inline void build_kernel(
   }
   // Add output strides and shape to extract the indices.
   if (!contiguous) {
-    os += fmt::format(
-        " constant const int64_t* output_strides [[buffer({0})]],\n", cnt++);
     os += fmt::format(
         " constant const int* output_shape [[buffer({0})]],\n", cnt++);
   } else {
@@ -125,7 +113,7 @@ inline void build_kernel(
     auto& x = inputs[i];
     auto& xname = namer.get_name(x);
-    if (is_constant(x)) {
+    if (is_constant(i)) {
       auto type_str = get_type_string(x.dtype());
       std::ostringstream ss;
       print_constant(ss, x);
@@ -271,11 +259,6 @@ inline void build_kernel(
 void Compiled::eval_gpu(
     const std::vector<array>& inputs,
     std::vector<array>& outputs) {
-  // Make the name for the kernel library
-  if (kernel_lib_.empty()) {
-    kernel_lib_ = build_lib_name(inputs_, outputs_, tape_, constant_ids_);
-  }
-
   // Get the kernel if someone else built it already
   auto& s = stream();
   auto& d = metal::device(s.device);
@@ -290,7 +273,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
@@ -302,7 +285,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ true,
         /* ndim = */ 0,
         /* dynamic_dims = */ false,
@@ -315,7 +298,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -328,7 +311,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ i,
         /* dynamic_dims = */ false,
@@ -342,7 +325,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -354,7 +337,7 @@ void Compiled::eval_gpu(
         inputs_,
         outputs_,
         tape_,
-        constant_ids_,
+        is_constant_,
         /* contiguous = */ false,
         /* ndim = */ 0,
         /* dynamic_dims = */ true,
@@ -363,70 +346,13 @@ void Compiled::eval_gpu(
     return kernel;
   });
 
-  // Figure out which kernel we are using
-  auto& output_shape = outputs[0].shape();
-  auto contiguous = compiled_check_contiguity(inputs, output_shape);
-
   // Collapse contiguous dims to route to a faster kernel if possible. Also
   // handle all broadcasting.
-  std::vector<Strides> initial_strides;
-  initial_strides.push_back(outputs[0].strides());
-  Shape shape;
-  std::vector<Strides> strides;
-  if (!contiguous) {
-    for (int i = 0; i < inputs.size(); i++) {
-      // Skip constants.
-      if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
-        continue;
-      }
-      auto& x = inputs[i];
-      // Skip scalar inputs.
-      if (is_scalar(x)) {
-        continue;
-      }
-      // Broadcast the inputs to the output shape.
-      Strides xstrides;
-      int j = 0;
-      for (; j < output_shape.size() - x.ndim(); j++) {
-        if (output_shape[j] == 1) {
-          xstrides.push_back(outputs[0].strides()[j]);
-        } else {
-          xstrides.push_back(0);
-        }
-      }
-      for (int i = 0; i < x.ndim(); i++, j++) {
-        if (x.shape(i) == 1) {
-          if (output_shape[j] == 1) {
-            xstrides.push_back(outputs[0].strides()[j]);
-          } else {
-            xstrides.push_back(0);
-          }
-        } else {
-          xstrides.push_back(x.strides()[i]);
-        }
-      }
-      initial_strides.push_back(std::move(xstrides));
-    }
-    std::tie(shape, strides) =
-        collapse_contiguous_dims(output_shape, initial_strides, INT32_MAX);
-  }
+  auto [contiguous, shape, strides] =
+      compiled_collapse_contiguous_dims(inputs, outputs[0], is_constant_);
 
-  bool large;
-  if (contiguous) {
-    size_t max_size = 0;
-    for (auto& in : inputs) {
-      max_size = std::max(max_size, in.data_size());
-    }
-    large = (max_size > UINT32_MAX);
-  } else {
-    size_t max_size = 0;
-    for (auto& o : outputs) {
-      max_size = std::max(max_size, o.size());
-    }
-    large = (max_size > UINT32_MAX);
-  }
+  // Whether to use large index.
+  bool large = compiled_use_large_index(inputs, outputs, contiguous);
 
   // Get the kernel from the lib
   int ndim = shape.size();
@@ -451,7 +377,7 @@ void Compiled::eval_gpu(
   int stride_idx = 1; // idx 0 is the output strides
   Strides in_strides;
   for (int i = 0; i < inputs.size(); i++) {
-    if (constant_ids_.find(inputs_[i].id()) != constant_ids_.end()) {
+    if (is_constant_(i)) {
      continue;
     }
     auto& x = inputs[i];
@@ -468,8 +394,7 @@ void Compiled::eval_gpu(
     compute_encoder.set_vector_bytes(in_strides, cnt++);
   }
 
-  compiled_allocate_outputs(
-      inputs, outputs, inputs_, constant_ids_, contiguous);
+  compiled_allocate_outputs(inputs, outputs, is_constant_, contiguous);
 
   // Put the outputs in
   for (auto& x : outputs) {
@@ -478,7 +403,6 @@ void Compiled::eval_gpu(
   // Put the output shape and strides in
   if (!contiguous) {
-    compute_encoder.set_vector_bytes(strides[0], cnt++);
     compute_encoder.set_vector_bytes(shape, cnt++);
   } else {
     auto size = outputs[0].data_size();

View File

@@ -1,16 +1,20 @@
 // Copyright © 2023-2024 Apple Inc.
 #include <cstdlib>
 #include <map>
+#include <sstream>
 #include <unordered_map>
 #include <unordered_set>
 
 #include "mlx/allocator.h"
+#include "mlx/backend/common/compiled.h"
 #include "mlx/compile.h"
 #include "mlx/compile_impl.h"
 #include "mlx/fast_primitives.h"
+#include "mlx/graph_utils.h"
 #include "mlx/primitives.h"
 #include "mlx/transforms.h"
 #include "mlx/transforms_impl.h"
+#include "mlx/utils.h"
 
 namespace mlx::core {
@@ -82,7 +86,54 @@ Compiled::Compiled(
       inputs_(std::move(inputs)),
       outputs_(std::move(outputs)),
       tape_(std::move(tape)),
-      constant_ids_(std::move(constant_ids)) {}
+      constant_ids_(std::move(constant_ids)),
+      is_constant_([this](size_t i) {
+        return constant_ids_.find(inputs_[i].id()) != constant_ids_.end();
+      }) {
+  // Build the kernel name.
+  NodeNamer namer;
+  std::ostringstream os;
+  std::ostringstream constant_hasher;
+
+  // Fill the input names. This is not really necessary, I just like having A,
+  // B, C, ... as the inputs.
+  for (const auto& x : inputs_) {
+    namer.get_name(x);
+  }
+
+  // The primitives describing the tape. For unary and binary primitives this
+  // must be enough to describe the full computation.
+  for (const auto& a : tape_) {
+    // name and type of output
+    os << namer.get_name(a) << kindof(a.dtype()) << a.itemsize();
+    // computation performed
+    a.primitive().print(os);
+    // name of inputs to the function
+    for (auto& inp : a.inputs()) {
+      os << namer.get_name(inp);
+    }
+  }
+  os << "_";
+  for (const auto& x : inputs_) {
+    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
+      os << "C";
+      print_constant(constant_hasher, x);
+    } else {
+      os << (is_scalar(x) ? "S" : "V");
+    }
+  }
+  os << "_";
+  for (const auto& x : inputs_) {
+    if (constant_ids_.find(x.id()) != constant_ids_.end()) {
+      continue;
+    }
+    os << kindof(x.dtype()) << x.itemsize();
+  }
+  os << "_" << std::hash<std::string>{}(constant_hasher.str());
+  kernel_lib_ = os.str();
+}
 
 std::vector<array> Compiled::vjp(
     const std::vector<array>&,
View File

@@ -627,6 +627,7 @@ class Compiled : public Primitive {
   const std::vector<array> outputs_;
   const std::vector<array> tape_;
   const std::unordered_set<uintptr_t> constant_ids_;
+  const std::function<bool(size_t)> is_constant_;
   std::string kernel_lib_;
 };

View File

@@ -208,7 +208,9 @@ array eval_impl(std::vector<array> outputs, bool async) {
         // output arrays stream
         fences[it->second].wait(stream, in);
       } else if (in.event().valid()) {
-        if (in.event().stream() != stream) {
+        if (in.event().is_signaled()) {
+          in.detach_event();
+        } else if (in.event().stream() != stream) {
           // Use event to wait across async eval
           in.event().wait(stream);
         }

View File

@@ -4,7 +4,7 @@
 #define MLX_VERSION_MAJOR 0
 #define MLX_VERSION_MINOR 26
-#define MLX_VERSION_PATCH 0
+#define MLX_VERSION_PATCH 1
 #define MLX_VERSION_NUMERIC \
   (100000 * MLX_VERSION_MAJOR + 1000 * MLX_VERSION_MINOR + MLX_VERSION_PATCH)

View File

@@ -193,7 +193,7 @@ class Module(dict):
             )
         if len(weights) != 0:
-            self.update(tree_unflatten(weights))
+            self.update(tree_unflatten(weights), strict=False)
         return self
 
     def save_weights(self, file: str):
@@ -291,7 +291,7 @@ class Module(dict):
         return self.filter_and_map(self.valid_child_filter, is_leaf_fn=_is_leaf_module)
 
-    def update(self, parameters: dict) -> Module:
+    def update(self, parameters: dict, strict: bool = True) -> Module:
         """Replace the parameters of this Module with the provided ones in the
         dict of dicts and lists.
@@ -306,6 +306,8 @@ class Module(dict):
         Args:
             parameters (dict): A complete or partial dictionary of the modules
               parameters.
+            strict (bool): If ``True`` checks that ``parameters`` is a
+              subset of the module's parameters. Default: ``True``.
 
         Returns:
             The module instance after updating the parameters.
         """
@@ -317,21 +319,29 @@ class Module(dict):
                         current_value = dst[k]
                         new_value = parameters[k]
                         if isinstance(current_value, mx.array):
+                            if strict and not isinstance(new_value, mx.array):
+                                raise ValueError(
+                                    f"Received invalid type: {type(new_value).__name__}."
+                                )
                             dst[k] = new_value
-                        elif isinstance(current_value, Module):
-                            current_value.update(new_value)
-                        elif isinstance(current_value, (dict, list)):
+                        else:
                             apply(current_value, new_value)
+                    elif strict:
+                        raise ValueError(f'Module does not have parameter named "{k}".')
             elif isinstance(parameters, list):
                 for i in range(len(parameters)):
                     current_value = dst[i]
                     new_value = parameters[i]
                     if isinstance(current_value, mx.array):
+                        if strict and not isinstance(new_value, mx.array):
+                            raise ValueError(
+                                f"Received invalid type: {type(new_value).__name__}."
+                            )
                         dst[i] = new_value
-                    elif isinstance(current_value, Module):
-                        current_value.update(new_value)
-                    elif isinstance(current_value, (dict, list)):
+                    else:
                         apply(current_value, new_value)
+            elif strict:
+                raise ValueError(f"Received invalid type: {type(parameters).__name__}.")
 
         apply(self, parameters)
         return self
@@ -359,7 +369,7 @@ class Module(dict):
         self.update(self.filter_and_map(filter_fn, map_fn))
         return self
 
-    def update_modules(self, modules: dict) -> Module:
+    def update_modules(self, modules: dict, strict: bool = True) -> Module:
         """Replace the child modules of this :class:`Module` instance with the
         provided ones in the dict of dicts and lists.
@@ -368,12 +378,14 @@ class Module(dict):
         programmatically swapping layers.
 
         The passed in parameters dictionary need not be a full dictionary
-        similar to :meth:`parameters`. Only the provided locations will be
+        similar to :meth:`modules`. Only the provided locations will be
         updated.
 
         Args:
-            modules (dict): A complete or partial dictionary of the modules
+            modules (dict): A complete or partial dictionary of the module's
              submodules.
+            strict (bool): If ``True`` checks that ``modules`` is a
+              subset of the child modules of this instance. Default: ``True``.
 
         Returns:
             The module instance after updating the submodules.
         """
@@ -388,6 +400,14 @@ class Module(dict):
                         dst[k] = new_value
                     elif isinstance(current_value, (dict, list)):
                         apply(current_value, new_value)
+                    elif strict:
+                        raise ValueError(
+                            f"Received invalid type: {type(new_value).__name__}."
+                        )
+                elif strict:
+                    raise ValueError(
+                        f'Module does not have sub-module named "{k}".'
+                    )
             elif isinstance(modules, list):
                 for i in range(len(dst)):
                     current_value = dst[i]
@@ -396,6 +416,12 @@ class Module(dict):
                         dst[i] = new_value
                     elif isinstance(current_value, (dict, list)):
                         apply(current_value, new_value)
+                    elif strict:
+                        raise ValueError(
+                            f"Received invalid type: {type(new_value).__name__}."
+                        )
+            elif strict:
+                raise ValueError(f"Received invalid type: {type(modules).__name__}.")
 
         apply(self, modules)
         return self
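
Note: with strict now the default, update() rejects unknown parameter names and mismatched types instead of silently ignoring them, while load_weights() opts out internally via strict=False. A minimal sketch of the resulting behavior (the module and shapes below are illustrative, not taken from this diff):

import mlx.core as mx
import mlx.nn as nn

m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

# Partial updates of existing parameters still work under strict mode.
m.update({"layers": [{"weight": mx.zeros((3, 3))}]})

# An unknown parameter name now raises instead of being silently dropped.
try:
    m.update({"layers": [{"wieght": mx.zeros((3, 3))}]})  # typo on purpose
except ValueError as e:
    print(e)  # Module does not have parameter named "wieght".

# strict=False restores the permissive behavior, which is what load_weights()
# now uses internally.
m.update({"layers": [{"wieght": mx.zeros((3, 3))}]}, strict=False)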

View File

@@ -219,6 +219,46 @@ class TestBase(mlx_tests.MLXTestCase):
         x = mx.zeros((3,))
         mx.grad(loss_fn)(model)
 
+    def test_update(self):
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+
+        # Updating non-existent parameters
+        with self.assertRaises(ValueError):
+            updates = {"layers": [{"value": 0}]}
+            m.update(updates)
+
+        with self.assertRaises(ValueError):
+            updates = {"layers": ["hello"]}
+            m.update(updates)
+
+        # Wrong type
+        with self.assertRaises(ValueError):
+            updates = {"layers": [{"weight": "hi"}]}
+            m.update(updates)
+
+    def test_update_modules(self):
+        m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))
+
+        # Updating non-existent modules should not be allowed by default
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"values": [0, 1]})
+
+        # Updating with wrong types
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"layers": [0, 1]})
+
+        class MyModule(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.test = mx.array(1.0)
+                self.list = [mx.array(1.0), mx.array(2.0)]
+
+        m = MyModule()
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"test": "hi"})
+        with self.assertRaises(ValueError):
+            m = m.update_modules({"list": ["hi"]})
+
 
 class TestLayers(mlx_tests.MLXTestCase):
     def test_identity(self):
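
For completeness, update_modules() follows the same strict-by-default pattern as update(). A small sketch of a programmatic layer swap (the replacement modules chosen here are illustrative):

import mlx.nn as nn

m = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3))

# Swapping existing child modules is still allowed. The list form is
# positional, so both entries are supplied here.
m.update_modules(
    {"layers": [nn.Linear(3, 3, bias=False), nn.Linear(3, 3, bias=False)]}
)

# Naming a child module that does not exist now raises a ValueError.
try:
    m.update_modules({"blocks": [nn.Linear(3, 3)]})
except ValueError as e:
    print(e)  # Module does not have sub-module named "blocks".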